/***************************************************************************
 * Copyright (c) Wolf Vollprecht, Johan Mabille and Sylvain Corlay          *
 * Copyright (c) QuantStack                                                 *
 *                                                                          *
 * Distributed under the terms of the BSD 3-Clause License.                 *
 *                                                                          *
 * The full license is in the file LICENSE, distributed with this software. *
 ****************************************************************************/

// This file is generated from test/files/cppy_source/test_lstsq.cppy by
// preprocess.py!

#include <algorithm>

#include "xtensor/xarray.hpp"
#include "xtensor/xfixed.hpp"
#include "xtensor/xnoalias.hpp"
#include "xtensor/xstrided_view.hpp"
#include "xtensor/xtensor.hpp"
#include "xtensor/xview.hpp"

#include "gtest/gtest.h"
#include "xtensor-blas/xlinalg.hpp"

namespace xt
{
    using namespace xt::placeholders;

    /*py
    a = np.random.random((6, 3))
    b = np.ones((6))
    */
    TEST(xtest_extended, lstsq1)
    {
        // py_a
        xarray<double> py_a = {
            {0.3745401188473625, 0.9507143064099162, 0.7319939418114051},
            {0.5986584841970366, 0.1560186404424365, 0.1559945203362026},
            {0.0580836121681995, 0.8661761457749352, 0.6011150117432088},
            {0.7080725777960455, 0.0205844942958024, 0.9699098521619943},
            {0.8324426408004217, 0.2123391106782762, 0.1818249672071006},
            {0.1834045098534338, 0.3042422429595377, 0.5247564316322378}};
        // py_b
        xarray<double> py_b = {1., 1., 1., 1., 1., 1.};
        // py_res0 = np.linalg.lstsq(a, b)[0]
        xarray<double> py_res0 = {0.99525656797683, 0.6379298291900684, 0.416589303565964};
        // py_res1 = np.linalg.lstsq(a, b)[1]
        xarray<double> py_res1 = {0.3378625895661748};
        // py_res2 = np.linalg.lstsq(a, b)[2]
        int py_res2 = 3;
        // py_res3 = np.linalg.lstsq(a, b)[3]
        xarray<double> py_res3 = {2.081504268698353, 1.012756249516551, 0.599044658280111};

        auto xres = xt::linalg::lstsq(py_a, py_b);
        EXPECT_TRUE(xt::allclose(std::get<0>(xres), py_res0));
        EXPECT_TRUE(xt::allclose(std::get<1>(xres), py_res1));
        EXPECT_EQ(std::get<2>(xres), py_res2);
        EXPECT_TRUE(xt::allclose(std::get<3>(xres), py_res3));
    }

    /*py
    a = np.random.random((3, 3))
    b = np.ones((3))
    */
    TEST(xtest_extended, lstsq20)
    {
        // py_a
        xarray<double> py_a = {
            {0.4319450186421158, 0.2912291401980419, 0.6118528947223795},
            {0.1394938606520418, 0.2921446485352182, 0.3663618432936917},
            {0.4560699842170359, 0.7851759613930136, 0.1996737821583597}};
        // py_b
        xarray<double> py_b = {1., 1., 1.};
        // py_res0 = np.linalg.lstsq(a, b)[0]
        xarray<double> py_res0 = {-1.655587220862159, 1.7320451450169407, 1.9787446378934206};
        // py_res1 = np.linalg.lstsq(a, b)[1]
        xarray<double> py_res1 = {};
        // py_res2 = np.linalg.lstsq(a, b)[2]
        int py_res2 = 3;
        // py_res3 = np.linalg.lstsq(a, b)[3]
        xarray<double> py_res3 = {1.2339483753871052, 0.4580824861786693, 0.1291723342275802};

        auto xres = xt::linalg::lstsq(py_a, py_b);

        EXPECT_TRUE(xt::allclose(std::get<0>(xres), py_res0));
        EXPECT_TRUE(xt::allclose(std::get<1>(xres), py_res1));
        EXPECT_EQ(std::get<2>(xres), py_res2);
        EXPECT_TRUE(xt::allclose(std::get<3>(xres), py_res3));
    }

    /*py
    a = np.random.random((3, 3))
    b = np.ones((3, 3))
    */
    TEST(xtest_extended, lstsq21)
    {
        // py_a
        xarray<double> py_a = {
            {0.5142344384136116, 0.5924145688620425, 0.0464504127199977},
            {0.6075448519014384, 0.1705241236872915, 0.0650515929852795},
            {0.9488855372533332, 0.9656320330745594, 0.8083973481164611}};
        // py_b
        xarray<double> py_b = {{1., 1., 1.}, {1., 1., 1.}, {1., 1., 1.}};
        // py_res0 = np.linalg.lstsq(a, b)[0]
        xarray<double> py_res0 = {
            {1.6749237812267237, 1.6749237812267237, 1.6749237812267237},
            {0.3213797243357512, 0.3213797243357512, 0.3213797243357512},
            {-1.1128753832544371, -1.1128753832544371, -1.1128753832544371}};
        // py_res1 = np.linalg.lstsq(a, b)[1]
        xarray<double> py_res1 = {};
        // py_res2 = np.linalg.lstsq(a, b)[2]
        int py_res2 = 3;
        // py_res3 = np.linalg.lstsq(a, b)[3]
        xarray<double> py_res3 = {1.8090476189892228, 0.4005423925178662, 0.2705890168670333};

        auto xres = xt::linalg::lstsq(py_a, py_b);

        EXPECT_TRUE(xt::allclose(std::get<0>(xres), py_res0));
        EXPECT_TRUE(xt::allclose(std::get<1>(xres), py_res1));
        EXPECT_EQ(std::get<2>(xres), py_res2);
        EXPECT_TRUE(xt::allclose(std::get<3>(xres), py_res3));
    }

    /*py
    a = np.random.random((2, 5))
    b = np.ones((2))
    */
    TEST(xtest_extended, lstsq3)
    {
        // py_a
        xarray<double> py_a = {
            {0.3046137691733707, 0.0976721140063839, 0.6842330265121569, 0.4401524937396013, 0.1220382348447788},
            {0.4951769101112702, 0.0343885211152184, 0.9093204020787821, 0.2587799816000169, 0.662522284353982}};
        // py_b
        xarray<double> py_b = {1., 1.};
        // py_res0 = np.linalg.lstsq(a, b)[0]
        xarray<double> py_res0 = {
            0.3137661125421979,
            0.183749537801855,
            0.8404557593671863,
            0.7586648365305537,
            -0.1845363594995904};
        // py_res1 = np.linalg.lstsq(a, b)[1]
        xarray<double> py_res1 = {};
        // py_res2 = np.linalg.lstsq(a, b)[2]
        int py_res2 = 2;
        // py_res3 = np.linalg.lstsq(a, b)[3]
        xarray<double> py_res3 = {1.4931292414997537, 0.3589512974668556};

        auto xres = xt::linalg::lstsq(py_a, py_b);
        EXPECT_TRUE(xt::allclose(std::get<0>(xres), py_res0));
        EXPECT_TRUE(xt::allclose(std::get<1>(xres), py_res1));
        EXPECT_EQ(std::get<2>(xres), py_res2);
        EXPECT_TRUE(xt::allclose(std::get<3>(xres), py_res3));
    }

    /*py
    a = np.random.random((2, 5))
    b = np.ones((2, 10))
    */
    TEST(xtest_extended, lstsq4)
    {
        // py_a
        xarray<double> py_a = {
            {0.311711076089411, 0.5200680211778108, 0.5467102793432796, 0.184854455525527, 0.9695846277645586},
            {0.7751328233611146, 0.9394989415641891, 0.8948273504276488, 0.5978999788110851, 0.9218742350231168}};
        // py_b
        xarray<double> py_b = {
            {1., 1., 1., 1., 1., 1., 1., 1., 1., 1.},
            {1., 1., 1., 1., 1., 1., 1., 1., 1., 1.}};
        // py_res0 = np.linalg.lstsq(a, b)[0]
        xarray<double> py_res0 = {
            {-0.0723929848964098,
             -0.0723929848964098,
             -0.0723929848964098,
             -0.0723929848964098,
             -0.0723929848964098,
             -0.0723929848964098,
             -0.0723929848964098,
             -0.0723929848964098,
             -0.0723929848964098,
             -0.0723929848964098},
            {0.1423971374668718,
             0.1423971374668718,
             0.1423971374668718,
             0.1423971374668718,
             0.1423971374668718,
             0.1423971374668718,
             0.1423971374668718,
             0.1423971374668718,
             0.1423971374668718,
             0.1423971374668718},
            {0.2187317829605842,
             0.2187317829605842,
             0.2187317829605842,
             0.2187317829605842,
             0.2187317829605842,
             0.2187317829605842,
             0.2187317829605842,
             0.2187317829605842,
             0.2187317829605842,
             0.2187317829605842},
            {-0.1457627271119433,
             -0.1457627271119433,
             -0.1457627271119433,
             -0.1457627271119433,
             -0.1457627271119433,
             -0.1457627271119433,
             -0.1457627271119433,
             -0.1457627271119433,
             -0.1457627271119433,
             -0.1457627271119433},
            {0.882719722037499,
             0.882719722037499,
             0.882719722037499,
             0.882719722037499,
             0.882719722037499,
             0.882719722037499,
             0.882719722037499,
             0.882719722037499,
             0.882719722037499,
             0.882719722037499}};
        // py_res1 = np.linalg.lstsq(a, b)[1]
        xarray<double> py_res1 = {};
        // py_res2 = np.linalg.lstsq(a, b)[2]
        int py_res2 = 2;
        // py_res3 = np.linalg.lstsq(a, b)[3]
        xarray<double> py_res3 = {2.23042850951828, 0.3968910268428817};

        auto xres = xt::linalg::lstsq(py_a, py_b);
        EXPECT_TRUE(xt::allclose(std::get<0>(xres), py_res0));
        EXPECT_TRUE(xt::allclose(std::get<1>(xres), py_res1));
        EXPECT_EQ(std::get<2>(xres), py_res2);
        EXPECT_TRUE(xt::allclose(std::get<3>(xres), py_res3));
    }

    /*py
    a = np.random.random((10, 5))
    b = np.ones((10, 20))
    */
    TEST(xtest_extended, lstsq5)
    {
        // py_a
        xarray<double> py_a = {
            {0.0884925020519195, 0.1959828624191452, 0.0452272889105381, 0.3253303307632643, 0.388677289689482},
            {0.2713490317738959, 0.8287375091519293, 0.3567533266935893, 0.2809345096873808, 0.5426960831582485},
            {0.1409242249747626, 0.8021969807540397, 0.0745506436797708, 0.9868869366005173, 0.7722447692966574},
            {0.1987156815341724, 0.0055221171236024, 0.8154614284548342, 0.7068573438476171, 0.7290071680409873},
            {0.7712703466859457, 0.0740446517340904, 0.3584657285442726, 0.1158690595251297, 0.8631034258755935},
            {0.6232981268275579, 0.3308980248526492, 0.0635583502860236, 0.3109823217156622, 0.325183322026747},
            {0.7296061783380641, 0.6375574713552131, 0.8872127425763265, 0.4722149251619493, 0.1195942459383017},
            {0.713244787222995, 0.7607850486168974, 0.5612771975694962, 0.770967179954561, 0.4937955963643907},
            {0.5227328293819941, 0.4275410183585496, 0.0254191267440952, 0.1078914269933045, 0.0314291856867343},
            {0.6364104112637804, 0.3143559810763267, 0.5085706911647028, 0.907566473926093, 0.2492922291488749}};
        // py_b
        xarray<double> py_b = {
            {1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.},
            {1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.},
            {1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.},
            {1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.},
            {1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.},
            {1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.},
            {1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.},
            {1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.},
            {1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.},
            {1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.}};
        // py_res0 = np.linalg.lstsq(a, b)[0]
        xarray<double> py_res0 = {
            {0.7695214798127482, 0.7695214798127482, 0.7695214798127482, 0.7695214798127482,
             0.7695214798127482, 0.7695214798127482, 0.7695214798127482, 0.7695214798127482,
             0.7695214798127482, 0.7695214798127482, 0.7695214798127482, 0.7695214798127482,
             0.7695214798127482, 0.7695214798127482, 0.7695214798127482, 0.7695214798127482,
             0.7695214798127483, 0.7695214798127483, 0.7695214798127483, 0.7695214798127483},
            {0.3603784058338763, 0.3603784058338763, 0.3603784058338763, 0.3603784058338763,
             0.3603784058338763, 0.3603784058338763, 0.3603784058338763, 0.3603784058338763,
             0.3603784058338763, 0.3603784058338763, 0.3603784058338763, 0.3603784058338763,
             0.3603784058338763, 0.3603784058338763, 0.3603784058338763, 0.3603784058338763,
             0.3603784058338762, 0.3603784058338762, 0.3603784058338763, 0.3603784058338763},
            {-0.0288908468951092, -0.0288908468951092, -0.0288908468951092, -0.0288908468951092,
             -0.0288908468951092, -0.0288908468951092, -0.0288908468951092, -0.0288908468951092,
             -0.0288908468951092, -0.0288908468951092, -0.0288908468951092, -0.0288908468951092,
             -0.0288908468951092, -0.0288908468951092, -0.0288908468951092, -0.0288908468951092,
             -0.0288908468951093, -0.0288908468951093, -0.0288908468951092, -0.0288908468951092},
            {0.2739420182164651, 0.2739420182164651, 0.2739420182164651, 0.2739420182164651,
             0.2739420182164651, 0.2739420182164651, 0.2739420182164651, 0.2739420182164651,
             0.2739420182164651, 0.2739420182164651, 0.2739420182164651, 0.2739420182164651,
             0.2739420182164651, 0.2739420182164651, 0.2739420182164651, 0.2739420182164651,
             0.2739420182164651, 0.2739420182164651, 0.2739420182164652, 0.2739420182164652},
            {0.6381721647626307, 0.6381721647626307, 0.6381721647626307, 0.6381721647626307,
             0.6381721647626307, 0.6381721647626307, 0.6381721647626307, 0.6381721647626307,
             0.6381721647626307, 0.6381721647626307, 0.6381721647626307, 0.6381721647626307,
             0.6381721647626307, 0.6381721647626307, 0.6381721647626307, 0.6381721647626307,
             0.6381721647626307, 0.6381721647626307, 0.6381721647626307, 0.6381721647626307}};
        // py_res1 = np.linalg.lstsq(a, b)[1]
        xarray<double> py_res1 = {
            0.6683875034141331, 0.6683875034141331, 0.6683875034141331, 0.6683875034141331,
            0.6683875034141331, 0.6683875034141331, 0.6683875034141331, 0.6683875034141331,
            0.6683875034141331, 0.6683875034141331, 0.6683875034141331, 0.6683875034141331,
            0.6683875034141331, 0.6683875034141331, 0.6683875034141331, 0.6683875034141331,
            0.6683875034141331, 0.6683875034141331, 0.6683875034141331, 0.6683875034141331};
        // py_res2 = np.linalg.lstsq(a, b)[2]
        int py_res2 = 5;
        // py_res3 = np.linalg.lstsq(a, b)[3]
        xarray<double> py_res3 =
            {3.317877520855451, 1.0262463009257718, 0.9696565206896536, 0.8384020117545181, 0.5915006407947916};

        auto xres = xt::linalg::lstsq(py_a, py_b);
        EXPECT_TRUE(xt::allclose(std::get<0>(xres), py_res0));
        EXPECT_TRUE(xt::allclose(std::get<1>(xres), py_res1));
        EXPECT_EQ(std::get<2>(xres), py_res2);
        EXPECT_TRUE(xt::allclose(std::get<3>(xres), py_res3));
    }

    /*py
    a = np.array([[0., 1.]])
    b = np.array([1.])
    */
    TEST(xtest_extended, lstsq6)
    {
        // py_a
        xarray<double> py_a = {{0., 1.}};
        // py_b
        xarray<double> py_b = {1.};
        // py_res0 = np.linalg.lstsq(a, b)[0]
        xarray<double> py_res0 = {0., 1.};
        // py_res1 = np.linalg.lstsq(a, b)[1]
        xarray<double> py_res1 = {};
        // py_res2 = np.linalg.lstsq(a, b)[2]
        int py_res2 = 1;
        // py_res3 = np.linalg.lstsq(a, b)[3]
        xarray<double> py_res3 = {1.};

        auto xres = xt::linalg::lstsq(py_a, py_b);
        EXPECT_TRUE(xt::allclose(std::get<0>(xres), py_res0));
        EXPECT_TRUE(xt::allclose(std::get<1>(xres), py_res1));
        EXPECT_EQ(std::get<2>(xres), py_res2);
        EXPECT_TRUE(xt::allclose(std::get<3>(xres), py_res3));
    }

    /*py
    a = np.array([[1.], [1.]])
    b = np.array([1., 1.])
    */
    TEST(xtest_extended, lstsq7)
    {
        // cannot use "// py_a" due to ambiguous initializer list conversion below
        // xarray<double> py_a = {{1.},
        //                        {1.}};
        xarray<double> py_a = xt::ones<double>({2, 1});
        // py_b
        xarray<double> py_b = {1., 1.};
        // py_res0 = np.linalg.lstsq(a, b)[0]
        xarray<double> py_res0 = {0.9999999999999997};
        // py_res1 = np.linalg.lstsq(a, b)[1]
        xarray<double> py_res1 = {2.2508083912556065e-33};
        // py_res2 = np.linalg.lstsq(a, b)[2]
        int py_res2 = 1;
        // py_res3 = np.linalg.lstsq(a, b)[3]
        xarray<double> py_res3 = {1.4142135623730951};

        auto xres = xt::linalg::lstsq(py_a, py_b);
        EXPECT_TRUE(xt::allclose(std::get<0>(xres), py_res0));
        EXPECT_TRUE(xt::allclose(std::get<1>(xres), py_res1));
        EXPECT_EQ(std::get<2>(xres), py_res2);
        EXPECT_TRUE(xt::allclose(std::get<3>(xres), py_res3));
    }

}  // namespace xt
